# -*- coding: utf-8 -*-
#
# Copyright 2016 Ricequant, Inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import pytz
import six
import redis

import pandas as pd
# from ..instruments import Instrument
from ...RunQuant.instruments import Instrument
from ..RedisCMG.redis_config import RedisConfig
from ..RedisCMG.redis_data import RedisDataReader
from ..const import DATABASES, REDIS


class LocalDataSource(object):
    """Market data source backed by MySQL, Redis and local bcolz tables.

    Instrument metadata and dividends are queried from the ``hqdata`` MySQL
    database, trading dates and K-line bars come from the ``hqdata`` Redis
    instance, and the dividend / yield-curve bcolz tables are opened from
    ``root_dir``.
    """

    # Data-set names, relative to ``root_dir``.
    DAILY = 'st_stock_days.bcolz'
    INSTRUMENTS = 'instruments.pk'
    DIVIDEND = 'original_dividends.bcolz'
    TRADING_DATES = 'trading_dates.bcolz'
    YIELD_CURVE = 'yield_curve.bcolz'

    # Tenor length in days -> column name in the yield-curve table.
    YIELD_CURVE_TENORS = {
        0: 'S0',
        30: 'M1',
        60: 'M2',
        90: 'M3',
        180: 'M6',
        270: 'M9',
        365: 'Y1',
        365 * 2: 'Y2',
        365 * 3: 'Y3',
        365 * 4: 'Y4',
        365 * 5: 'Y5',
        365 * 6: 'Y6',
        365 * 7: 'Y7',
        365 * 8: 'Y8',
        365 * 9: 'Y9',
        365 * 10: 'Y10',
        365 * 15: 'Y15',
        365 * 20: 'Y20',
        365 * 30: 'Y30',
        365 * 40: 'Y40',
        365 * 50: 'Y50',
    }

    # Tenor lengths in ascending order, used for the tenor lookup in get_yield_curve.
    YIELD_CURVE_DURATION = sorted(YIELD_CURVE_TENORS.keys())

    PRICE_SCALE = 1000.

    def __init__(self, root_dir):
        """Open every backing store and load instruments and trading dates.

        :param root_dir: directory containing the dividend and yield-curve
            bcolz tables.
        """
        self._root_dir = root_dir
        import bcolz
        import os
        import MySQLdb

        db_cfg = DATABASES["hqdata"]
        self._conn = MySQLdb.connect(host=db_cfg.get("host"), user=db_cfg.get("user"),
                                     passwd=db_cfg.get("password"), db=db_cfg.get("db"),
                                     charset="utf8")
        redis_cfg = REDIS["hqdata"]
        # NOTE(review): hard-coded fallback password; consider requiring it in config.
        self._redisConfig = RedisConfig(host=redis_cfg.get("host"), port=redis_cfg.get("port", 6379), db=0,
                                        password=redis_cfg.get("password", "redisserver"))
        self._redisDataReader = RedisDataReader(self._redisConfig)
        # Index instruments by order_book_id so instruments() is a dict lookup.
        self._instruments = {d['order_book_id']: Instrument(d)
                             for d in self.__getAllInstruments()}
        self._dividend = bcolz.open(os.path.join(root_dir, LocalDataSource.DIVIDEND))
        self._yield_curve = bcolz.open(os.path.join(root_dir, LocalDataSource.YIELD_CURVE))
        self._trading_dates = self._redisDataReader.getTradeDate()

    def close(self):
        """Close the MySQL connection opened in ``__init__``; safe to call repeatedly."""
        conn = getattr(self, '_conn', None)
        if conn is not None:
            conn.close()
            self._conn = None

    def instruments(self, order_book_ids):
        """Look up instrument (code table) entries by order_book_id.

        :param order_book_ids: a single id string, or an iterable of ids.
        :return: for a single id, the Instrument, or None (after printing an
            error) when the id is unknown; for an iterable, a list of the
            known Instruments in input order, silently skipping unknown ids.
        """
        if isinstance(order_book_ids, six.string_types):
            try:
                return self._instruments[order_book_ids]
            except KeyError:
                print('ERROR: order_book_id {} not exists!'.format(order_book_ids))
                return None

        return [self._instruments[ob] for ob in order_book_ids
                if ob in self._instruments]

    def all_instruments(self):
        """Return every loaded instrument as a DataFrame, one row per instrument."""
        return pd.DataFrame([v.__dict__ for v in self._instruments.values()])

    def __getAllInstruments(self):
        """Fetch the Shanghai/Shenzhen A-share code table as a list of dicts.

        ``order_book_id`` is built as ``code.market``. The rows are round-tripped
        through JSON, so NaN becomes None and the dicts hold plain JSON types;
        ``reset_index`` also adds an ``index`` key to every record.
        """
        import json
        strSql = 'select concat(code,".",market) as order_book_id,name as symbol,market,industry,area,pe,outstanding,totals,totalAssets,liquidAssets,reserved,reservedPerShare,\
        esp,bvps,pb,timeToMarket from stock_baseinfo'
        df = pd.read_sql(strSql, self._conn)
        return json.loads(df.reset_index().to_json(orient='records'))

    # NOTE(review): the three filters below read ``type``, ``sector_code``,
    # ``industry_code`` and ``concept_names``, none of which is selected by the
    # stock_baseinfo query above; confirm Instrument supplies them.
    def sector(self, code):
        """Return order_book_ids of common stocks ('CS') whose sector_code equals ``code``."""
        return [v.order_book_id for v in self._instruments.values()
                if v.type == 'CS' and v.sector_code == code]

    def industry(self, code):
        """Return order_book_ids of common stocks ('CS') whose industry_code equals ``code``."""
        return [v.order_book_id for v in self._instruments.values()
                if v.type == 'CS' and v.industry_code == code]

    def concept(self, *concepts):
        """Return order_book_ids of common stocks tagged with any of ``concepts``.

        ``concept_names`` is a '|'-separated string.
        """
        return [v.order_book_id for v in self._instruments.values()
                if v.type == 'CS' and any(c in v.concept_names.split('|') for c in concepts)]

    def get_trading_dates(self, start_date, end_date):
        """Return the trading dates in ``[start_date, end_date]``, skipping non-trading days.

        :param start_date: first date (inclusive).
        :param end_date: last date (inclusive).
        :return: slice of the sorted trading-date index loaded from Redis.
        """
        left = self._trading_dates.searchsorted(start_date)
        right = self._trading_dates.searchsorted(end_date, side='right')
        return self._trading_dates[left:right]

    def get_trading_minutes(self, start_time, end_time):
        """Return the one-minute timestamps between ``start_time`` and ``end_time``.

        For each selected trading date, a morning session (09:31-11:30) and an
        afternoon session (13:00-15:00) are generated at one-minute frequency.
        The concatenation is then clipped to ``[start_time, end_time]``, both
        bounds inclusive.

        :param start_time: first timestamp of interest.
        :param end_time: last timestamp of interest (inclusive).
        :return: pandas DatetimeIndex of minutes.
        :raises RuntimeError: if no trading date is selected.
        """
        left_date = self._trading_dates.searchsorted(start_time)
        right_date = self._trading_dates.searchsorted(end_time, side='right')
        # NOTE(review): when both bounds map to the same insertion point, the
        # slice is widened by one, which picks the trading date *after* that
        # point. If start_time carries a time of day later than the stored
        # date, the searchsorted above may already have skipped that date, so
        # intraday ranges could come back empty. Confirm the trading-date type
        # and the callers' expectations before changing this.
        if left_date == right_date:
            trade_dates = self._trading_dates[left_date:right_date + 1]
        else:
            trade_dates = self._trading_dates[left_date:right_date]
        if len(trade_dates) == 0:
            raise RuntimeError('the length of trade data is 0')
        # Collect every session first and concatenate once; appending inside
        # the loop would re-copy the growing index on every iteration.
        sessions = []
        for trade_date in trade_dates:
            day = str(trade_date)
            sessions.append(pd.date_range(day + ' 09:31:00', day + ' 11:30:00', freq='min'))
            sessions.append(pd.date_range(day + ' 13:00:00', day + ' 15:00:00', freq='min'))
        minutes = pd.DatetimeIndex([]).append(sessions)
        left_minutes = minutes.searchsorted(start_time)
        right_minutes = minutes.searchsorted(end_time, side='right')
        return minutes[left_minutes:right_minutes]

    def get_yield_curve(self, start_date, end_date):
        """Return the yield for the tenor that best fits ``[start_date, end_date]``.

        The tenor is the longest entry of ``YIELD_CURVE_DURATION`` (in days)
        that does not exceed ``(end_date - start_date).days``. Negative spans
        fall back to the 0-day ``S0`` column. The value is the last entry of
        that column among rows whose ``date`` is on or before ``start_date``
        (this assumes the table is stored in date order), divided by 10000.

        :param start_date: date-like with ``year``/``month``/``day``.
        :param end_date: date-like; ``end_date - start_date`` must expose ``.days``.
        :return: the yield as a float.
        """
        import bisect

        durations = LocalDataSource.YIELD_CURVE_DURATION
        duration = (end_date - start_date).days
        # bisect_right counts the tenors <= duration; the previous index is the
        # longest tenor that still fits. Clamp to the shortest (0 days).
        tenor = durations[max(bisect.bisect_right(durations, duration) - 1, 0)]
        # The table's ``date`` column is compared against a YYYYMMDD integer.
        d = start_date.year * 10000 + start_date.month * 100 + start_date.day
        rows = self._yield_curve.fetchwhere('date<={}'.format(d))
        # Stored values are scaled by 10000 (presumably basis points - TODO confirm).
        return rows.cols[self.YIELD_CURVE_TENORS[tenor]][-1] / 10000.0

    def get_dividends(self, order_book_id):
        """Return cash-dividend records (``divitype = 1``) from ``stock_divident_factor``.

        NOTE(review): ``order_book_id`` is currently ignored. The query returns
        rows for every code, and its ``order_book_id`` column is the bare
        ``code`` without the ``.market`` suffix that ``__getAllInstruments``
        uses. Callers presumably filter the result themselves.

        :return: DataFrame with columns order_book_id, ex_dividend_date,
            dividendCash and round_lot (always 10).
        :raises RuntimeError: if the query fails; the original error is chained.
        """
        strSql = 'select code as order_book_id,date as ex_dividend_date,dividendCash,10 as round_lot from stock_divident_factor where divitype = 1'
        try:
            return pd.read_sql(strSql, self._conn)
        except Exception as e:
            # Narrower than a bare except (lets KeyboardInterrupt through) and
            # keeps the database error attached as the cause.
            six.raise_from(RuntimeError("Mysql operator err"), e)

    def get_all_bars(self, order_book_id):
        """Return all daily ('1d') K-line bars for ``order_book_id`` from Redis."""
        return self._redisDataReader.getAllKlineData(order_book_id, '1d', True)

    def get_minute_bars(self, order_book_id, frequency):
        """Return all K-line bars of ``frequency`` for ``order_book_id`` from Redis."""
        return self._redisDataReader.getAllKlineData(order_book_id, frequency)
